perf: optimize LTX2 inference latency and implement granular TPU profiling #389

Open
mbohlool wants to merge 1 commit into main from mehdy_perf

Conversation

@mbohlool
Collaborator

Optimize LTX2 inference latency and implement granular TPU profiling

Description

This PR introduces critical performance optimizations and comprehensive profiling infrastructure for the LTX2 video generation pipeline on TPU hardware.

Key Changes

1. Inference Parallelism Optimization (ltx2_video.yml)

Switched from ICI Context Parallelism (ici_context_parallelism: 1) to ICI Data Parallelism (ici_data_parallelism: -1).

  • Impact: Since Classifier-Free Guidance (CFG) generates independent batch items, DP is "embarrassingly parallel" for inference: it requires zero cross-core communication, completely bypassing the All-Gather ICI bottleneck that sequence-sharding incurs.
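
The config change above amounts to flipping two keys. As a sketch (the key names `ici_context_parallelism` and `ici_data_parallelism` come from this PR; the surrounding file structure is an assumption):

```yaml
# ltx2_video.yml (excerpt, hypothetical layout)
# Before: sequence was sharded across ICI cores, forcing All-Gathers.
# After: each core handles an independent CFG batch item end to end.
ici_context_parallelism: 1    # disable sequence-sharding
ici_data_parallelism: -1      # -1 = use all available cores for data parallelism
```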

2. Granular XLA Profiling Annotations (ltx2_pipeline.py)

Injected jax.named_scope wrappers around all major TPU-bound compute blocks (Connectors, Video VAE, Audio VAE, Vocoder).

  • Impact: This prevents massive operations from appearing as unlabeled blobs in the Cloud TPU Profiler (xprof), enabling accurate FLOPs tracking and roofline analysis for individual components outside of the main denoising loop.
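
A minimal sketch of the annotation pattern, using `jax.named_scope` (a real JAX context manager). The stage names and the toy compute are assumptions for illustration; the actual pipeline wraps the Connectors, Video VAE, Audio VAE, and Vocoder this way:

```python
import jax
import jax.numpy as jnp

@jax.jit
def decode_stage(latents):
    # Ops traced inside a named_scope carry that label in the xprof trace,
    # instead of showing up as an unnamed fused blob.
    with jax.named_scope("video_vae_decode"):
        x = jnp.tanh(latents) * 2.0
    with jax.named_scope("vocoder"):
        y = jnp.sum(x, axis=-1)
    return y

out = decode_stage(jnp.ones((2, 8)))
print(out.shape)  # (2,)
```

Because `named_scope` only tags the trace metadata, it adds no runtime cost and does not change the compiled computation.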

3. Execution Timing & Benchmarking (generate_ltx2.py & ltx2_pipeline.py)

Added synchronous jax.block_until_ready() wrappers at the boundaries of major pipeline stages to accurately measure execution time without asynchronous JAX dispatch artifacts.

  • Impact: Restructured the generation script into a 3-pass strategy (Warmup, Generation, Profiling). With the warmup pass absorbing JIT compilation, skip_first_n_steps_for_profiler: 0 lets the profiling pass capture pure steady-state execution latency, fully isolated from compilation overhead.
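
The timing pattern described above can be sketched as follows. The stage function and shapes are placeholders; the key mechanics are `block_until_ready()` (JAX dispatch is asynchronous, so without it `perf_counter` would only measure enqueue time) and a warmup call that absorbs JIT compilation:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def run_stage(x):
    # Stand-in for one pipeline stage (e.g. a VAE decode).
    return jnp.dot(x, x.T)

x = jnp.ones((256, 256))

# Pass 1 (warmup): triggers JIT compilation; its latency is discarded.
run_stage(x).block_until_ready()

# Pass 2 (generation): block_until_ready() makes the wall-clock delta
# reflect actual device execution, not just async dispatch.
t0 = time.perf_counter()
out = run_stage(x)
out.block_until_ready()
elapsed = time.perf_counter() - t0
print(f"stage latency: {elapsed * 1000:.2f} ms")

# Pass 3 (profiling) would wrap the same call in jax.profiler.trace(...);
# since compilation already happened in pass 1, no steps need skipping.
```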

@mbohlool mbohlool requested a review from entrpn as a code owner April 23, 2026 00:20
@github-actions

- Switched to DP (ici_data_parallelism: -1) in ltx2 config to bypass ICI communication overhead during inference.
- Added `jax.named_scope` around connectors and VAE blocks for accurate xprof trace attribution.
- Added synchronous `perf_counter` wrappers in the pipeline to measure true stage latencies.
- Implemented a 3-pass (warmup, run, profile) generation loop in `generate_ltx2.py` to isolate JIT compilation time and enable cleaner profiling.
@Perseus14
Collaborator

@mbohlool Could you add a table with the latency gain (single video and amortized throughput) of this change over the baseline (main)?

Thanks!
